!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief lower level routines for linear scaling SCF
!> \par History
!>       2010.10 created [Joost VandeVondele]
!> \author Joost VandeVondele
! **************************************************************************************************
MODULE dm_ls_scf_methods
   USE arnoldi_api,                     ONLY: arnoldi_extremal
   USE cp_dbcsr_api,                    ONLY: &
        dbcsr_add, dbcsr_add_on_diag, dbcsr_copy, dbcsr_create, dbcsr_desymmetrize, dbcsr_dot, &
        dbcsr_filter, dbcsr_finalize, dbcsr_frobenius_norm, dbcsr_get_data_type, &
        dbcsr_get_occupation, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
        dbcsr_iterator_start, dbcsr_iterator_stop, dbcsr_iterator_type, dbcsr_multiply, &
        dbcsr_put_block, dbcsr_release, dbcsr_scale, dbcsr_set, dbcsr_trace, dbcsr_type, &
        dbcsr_type_no_symmetry
   USE cp_log_handling,                 ONLY: cp_get_default_logger,&
                                              cp_logger_get_default_unit_nr,&
                                              cp_logger_type
   USE dm_ls_scf_qs,                    ONLY: matrix_qs_to_ls
   USE dm_ls_scf_types,                 ONLY: ls_cluster_atomic,&
                                              ls_mstruct_type,&
                                              ls_scf_env_type
   USE input_constants,                 ONLY: &
        ls_cluster_atomic, ls_s_preconditioner_atomic, ls_s_preconditioner_molecular, &
        ls_s_preconditioner_none, ls_s_sqrt_ns, ls_s_sqrt_proot, ls_scf_sign_ns, &
        ls_scf_sign_proot, ls_scf_sign_submatrix, ls_scf_submatrix_sign_direct_muadj, &
        ls_scf_submatrix_sign_direct_muadj_lowmem, ls_scf_submatrix_sign_ns
   USE iterate_matrix,                  ONLY: invert_Hotelling,&
                                              matrix_sign_Newton_Schulz,&
                                              matrix_sign_proot,&
                                              matrix_sign_submatrix,&
                                              matrix_sign_submatrix_mu_adjust,&
                                              matrix_sqrt_Newton_Schulz,&
                                              matrix_sqrt_proot
   USE kinds,                           ONLY: dp,&
                                              int_8
   USE machine,                         ONLY: m_flush,&
                                              m_walltime
   USE mathlib,                         ONLY: abnormal_value
#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dm_ls_scf_methods'

   PUBLIC :: ls_scf_init_matrix_S
   PUBLIC :: density_matrix_sign, density_matrix_sign_fixed_mu
   PUBLIC :: apply_matrix_preconditioner, compute_matrix_preconditioner
   PUBLIC :: density_matrix_trs4, density_matrix_tc2, compute_homo_lumo

CONTAINS

! **************************************************************************************************
!> \brief initialize S matrix related properties (sqrt, inverse...)
!>        Might be factored-out since this seems common code with the other SCF.
!>        The work is done in stages, each switched on by a flag of ls_scf_env:
!>        1. store S (or the identity if has_unit_metric) in ls_scf_env%matrix_s and filter it
!>        2. has_s_preconditioner: build sqrt(bs) and inv(sqrt(bs)) and apply the forward
!>           preconditioner to ls_scf_env%matrix_s in place
!>        3. use_s_sqrt: compute sqrt(S) and inv(sqrt(S)) with the method s_sqrt_method
!>        4. needs_s_inv: compute inv(S), as inv(sqrt(S))*inv(sqrt(S)) when use_s_sqrt is set,
!>           otherwise with invert_Hotelling
!>        With check_s_inv, the Frobenius norm of (product - I), divided by the Frobenius norm
!>        of the product, is printed for stages 3 and 4 on the source process.
!> \param matrix_s overlap matrix passed to matrix_qs_to_ls (not read when has_unit_metric is set)
!> \param ls_scf_env linear scaling SCF environment whose S-related matrices are set up
!> \par History
!>       2010.10 created [Joost VandeVondele]
!> \author Joost VandeVondele
! **************************************************************************************************
   SUBROUTINE ls_scf_init_matrix_S(matrix_s, ls_scf_env)
      TYPE(dbcsr_type)                                   :: matrix_s
      TYPE(ls_scf_env_type)                              :: ls_scf_env

      CHARACTER(len=*), PARAMETER :: routineN = 'ls_scf_init_matrix_S'

      INTEGER                                            :: iw, timer
      REAL(KIND=dp)                                      :: norm_err, norm_ref
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dbcsr_type)                                   :: work1, work2

      CALL timeset(routineN, timer)

      ! only the source process gets an output unit; everybody else keeps iw = -1
      logger => cp_get_default_logger()
      iw = -1
      IF (logger%para_env%is_source()) iw = cp_logger_get_default_unit_nr(logger, local=.TRUE.)

      ! our private copy of S: the LS-layout overlap, or the identity for a unit metric
      IF (.NOT. ls_scf_env%has_unit_metric) THEN
         CALL matrix_qs_to_ls(ls_scf_env%matrix_s, matrix_s, ls_scf_env%ls_mstruct, covariant=.TRUE.)
      ELSE
         CALL dbcsr_set(ls_scf_env%matrix_s, 0.0_dp)
         CALL dbcsr_add_on_diag(ls_scf_env%matrix_s, 1.0_dp)
      END IF
      CALL dbcsr_filter(ls_scf_env%matrix_s, ls_scf_env%eps_filter)

      ! block preconditioner: build sqrt(bs) / inv(sqrt(bs)), then S <- inv(sqrt(bs))*S*inv(sqrt(bs))
      IF (ls_scf_env%has_s_preconditioner) THEN
         CALL dbcsr_create(ls_scf_env%matrix_bs_sqrt, template=ls_scf_env%matrix_s, &
                           matrix_type=dbcsr_type_no_symmetry)
         CALL dbcsr_create(ls_scf_env%matrix_bs_sqrt_inv, template=ls_scf_env%matrix_s, &
                           matrix_type=dbcsr_type_no_symmetry)
         CALL compute_matrix_preconditioner(ls_scf_env%matrix_s, &
                                            ls_scf_env%s_preconditioner_type, ls_scf_env%ls_mstruct, &
                                            ls_scf_env%matrix_bs_sqrt, ls_scf_env%matrix_bs_sqrt_inv, &
                                            ls_scf_env%eps_filter, ls_scf_env%s_sqrt_order, &
                                            ls_scf_env%eps_lanczos, ls_scf_env%max_iter_lanczos)
         CALL apply_matrix_preconditioner(ls_scf_env%matrix_s, "forward", &
                                          ls_scf_env%matrix_bs_sqrt, ls_scf_env%matrix_bs_sqrt_inv)
      END IF

      ! sqrt(S) and inv(sqrt(S))
      IF (ls_scf_env%use_s_sqrt) THEN
         CALL dbcsr_create(ls_scf_env%matrix_s_sqrt, template=ls_scf_env%matrix_s, &
                           matrix_type=dbcsr_type_no_symmetry)
         CALL dbcsr_create(ls_scf_env%matrix_s_sqrt_inv, template=ls_scf_env%matrix_s, &
                           matrix_type=dbcsr_type_no_symmetry)

         SELECT CASE (ls_scf_env%s_sqrt_method)
         CASE (ls_s_sqrt_ns)
            CALL matrix_sqrt_Newton_Schulz(ls_scf_env%matrix_s_sqrt, ls_scf_env%matrix_s_sqrt_inv, &
                                           ls_scf_env%matrix_s, ls_scf_env%eps_filter, &
                                           ls_scf_env%s_sqrt_order, &
                                           ls_scf_env%eps_lanczos, ls_scf_env%max_iter_lanczos)
         CASE (ls_s_sqrt_proot)
            CALL matrix_sqrt_proot(ls_scf_env%matrix_s_sqrt, ls_scf_env%matrix_s_sqrt_inv, &
                                   ls_scf_env%matrix_s, ls_scf_env%eps_filter, &
                                   ls_scf_env%s_sqrt_order, &
                                   ls_scf_env%eps_lanczos, ls_scf_env%max_iter_lanczos, &
                                   symmetrize=.TRUE.)
         CASE DEFAULT
            CPABORT("Unknown sqrt method.")
         END SELECT

         IF (ls_scf_env%check_s_inv) THEN
            ! work2 = inv(sqrt(S))*S*inv(sqrt(S)), ideally the identity
            CALL dbcsr_create(work1, template=ls_scf_env%matrix_s, matrix_type=dbcsr_type_no_symmetry)
            CALL dbcsr_create(work2, template=ls_scf_env%matrix_s, matrix_type=dbcsr_type_no_symmetry)
            CALL dbcsr_multiply("N", "N", 1.0_dp, ls_scf_env%matrix_s_sqrt_inv, ls_scf_env%matrix_s, &
                                0.0_dp, work1, filter_eps=ls_scf_env%eps_filter)
            CALL dbcsr_multiply("N", "N", 1.0_dp, work1, ls_scf_env%matrix_s_sqrt_inv, &
                                0.0_dp, work2, filter_eps=ls_scf_env%eps_filter)
            CALL dbcsr_release(work1)

            norm_ref = dbcsr_frobenius_norm(work2)
            CALL dbcsr_add_on_diag(work2, -1.0_dp)
            norm_err = dbcsr_frobenius_norm(work2)
            CALL dbcsr_release(work2)

            IF (iw > 0) WRITE (iw, *) "Error for (inv(sqrt(S))*S*inv(sqrt(S))-I)", norm_err/norm_ref
         END IF
      END IF

      ! inv(S), either from inv(sqrt(S)) or by Hotelling iteration
      IF (ls_scf_env%needs_s_inv) THEN
         CALL dbcsr_create(ls_scf_env%matrix_s_inv, template=ls_scf_env%matrix_s, &
                           matrix_type=dbcsr_type_no_symmetry)
         IF (ls_scf_env%use_s_sqrt) THEN
            CALL dbcsr_multiply("N", "N", 1.0_dp, ls_scf_env%matrix_s_sqrt_inv, ls_scf_env%matrix_s_sqrt_inv, &
                                0.0_dp, ls_scf_env%matrix_s_inv, filter_eps=ls_scf_env%eps_filter)
         ELSE
            CALL invert_Hotelling(ls_scf_env%matrix_s_inv, ls_scf_env%matrix_s, ls_scf_env%eps_filter)
         END IF

         IF (ls_scf_env%check_s_inv) THEN
            ! work1 = inv(S)*S, ideally the identity
            CALL dbcsr_create(work1, template=ls_scf_env%matrix_s, matrix_type=dbcsr_type_no_symmetry)
            CALL dbcsr_multiply("N", "N", 1.0_dp, ls_scf_env%matrix_s_inv, ls_scf_env%matrix_s, &
                                0.0_dp, work1, filter_eps=ls_scf_env%eps_filter)

            norm_ref = dbcsr_frobenius_norm(work1)
            CALL dbcsr_add_on_diag(work1, -1.0_dp)
            norm_err = dbcsr_frobenius_norm(work1)
            CALL dbcsr_release(work1)

            IF (iw > 0) WRITE (iw, *) "Error for (inv(S)*S-I)", norm_err/norm_ref
         END IF
      END IF

      CALL timestop(timer)
   END SUBROUTINE ls_scf_init_matrix_S

! **************************************************************************************************
!> \brief compute for a block positive definite matrix s (bs)
!>        the sqrt(bs) and inv(sqrt(bs))
!>        bs is built from the blocks of matrix_s selected by preconditioner_type:
!>        - ls_s_preconditioner_none: no blocks; both outputs are set to the identity
!>        - ls_s_preconditioner_atomic: the diagonal blocks only
!>        - ls_s_preconditioner_molecular: the diagonal blocks, plus (for atomic clustering only)
!>          the off-diagonal blocks whose row and column atoms belong to the same molecule
!>        Any other preconditioner_type aborts.
!> \param matrix_s matrix whose blocks are copied into bs
!> \param preconditioner_type one of ls_s_preconditioner_none/atomic/molecular
!> \param ls_mstruct matrix structure; cluster_type and atom_to_molecule are read
!> \param matrix_bs_sqrt on output, sqrt(bs) (identity for ls_s_preconditioner_none)
!> \param matrix_bs_sqrt_inv on output, inv(sqrt(bs)) (identity for ls_s_preconditioner_none)
!> \param threshold filter threshold; the Newton-Schulz iteration uses MIN(threshold, 1.0E-10)
!> \param order order passed to matrix_sqrt_Newton_Schulz
!> \param eps_lanczos passed to matrix_sqrt_Newton_Schulz
!> \param max_iter_lanczos passed to matrix_sqrt_Newton_Schulz
!> \par History
!>       2010.10 created [Joost VandeVondele]
!> \author Joost VandeVondele
! **************************************************************************************************
   SUBROUTINE compute_matrix_preconditioner(matrix_s, preconditioner_type, ls_mstruct, &
                                            matrix_bs_sqrt, matrix_bs_sqrt_inv, threshold, order, eps_lanczos, max_iter_lanczos)

      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_s
      INTEGER                                            :: preconditioner_type
      TYPE(ls_mstruct_type)                              :: ls_mstruct
      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_bs_sqrt, matrix_bs_sqrt_inv
      REAL(KIND=dp)                                      :: threshold
      INTEGER                                            :: order
      REAL(KIND=dp)                                      :: eps_lanczos
      INTEGER                                            :: max_iter_lanczos

      CHARACTER(LEN=*), PARAMETER :: routineN = 'compute_matrix_preconditioner'

      INTEGER                                            :: handle, iblock_col, iblock_row
      LOGICAL                                            :: block_needed
      REAL(dp), DIMENSION(:, :), POINTER                 :: block_dp
      TYPE(dbcsr_iterator_type)                          :: iter
      TYPE(dbcsr_type)                                   :: matrix_bs

      CALL timeset(routineN, handle)

      ! first generate a block diagonal copy of s
      CALL dbcsr_create(matrix_bs, template=matrix_s)

      SELECT CASE (preconditioner_type)
      CASE (ls_s_preconditioner_none)
         ! no blocks needed, bs stays empty
      CASE (ls_s_preconditioner_atomic, ls_s_preconditioner_molecular)
         CALL dbcsr_iterator_start(iter, matrix_s)
         DO WHILE (dbcsr_iterator_blocks_left(iter))
            CALL dbcsr_iterator_next_block(iter, iblock_row, iblock_col, block_dp)

            ! do we need the block ?
            ! this depends on the preconditioner, but also the matrix clustering method employed
            ! for a clustered matrix, right now, we assume that atomic and molecular preconditioners
            ! are actually the same, and only require that the diagonal blocks (clustered) are present

            block_needed = .FALSE.

            IF (iblock_row == iblock_col) THEN
               block_needed = .TRUE.
            ELSE
               IF (preconditioner_type == ls_s_preconditioner_molecular .AND. &
                   ls_mstruct%cluster_type == ls_cluster_atomic) THEN
                  IF (ls_mstruct%atom_to_molecule(iblock_row) == ls_mstruct%atom_to_molecule(iblock_col)) block_needed = .TRUE.
               END IF
            END IF

            ! add it
            IF (block_needed) THEN
               CALL dbcsr_put_block(matrix=matrix_bs, row=iblock_row, col=iblock_col, block=block_dp)
            END IF

         END DO
         CALL dbcsr_iterator_stop(iter)
      CASE DEFAULT
         ! otherwise matrix_bs_sqrt / matrix_bs_sqrt_inv would silently be left without data
         CPABORT("Unknown preconditioner type.")
      END SELECT

      CALL dbcsr_finalize(matrix_bs)

      SELECT CASE (preconditioner_type)
      CASE (ls_s_preconditioner_none)
         ! for now make it a simple identity matrix
         CALL dbcsr_copy(matrix_bs_sqrt, matrix_bs)
         CALL dbcsr_set(matrix_bs_sqrt, 0.0_dp)
         CALL dbcsr_add_on_diag(matrix_bs_sqrt, 1.0_dp)

         ! for now make it a simple identity matrix
         CALL dbcsr_copy(matrix_bs_sqrt_inv, matrix_bs)
         CALL dbcsr_set(matrix_bs_sqrt_inv, 0.0_dp)
         CALL dbcsr_add_on_diag(matrix_bs_sqrt_inv, 1.0_dp)
      CASE (ls_s_preconditioner_atomic, ls_s_preconditioner_molecular)
         CALL dbcsr_copy(matrix_bs_sqrt, matrix_bs)
         CALL dbcsr_copy(matrix_bs_sqrt_inv, matrix_bs)
         ! the threshold here could be done differently: using eps_filter would reduce accuracy
         ! for no good reason since this step is cheap, hence it is capped at 1.0E-10
         CALL matrix_sqrt_Newton_Schulz(matrix_bs_sqrt, matrix_bs_sqrt_inv, matrix_bs, &
                                        threshold=MIN(threshold, 1.0E-10_dp), order=order, &
                                        eps_lanczos=eps_lanczos, max_iter_lanczos=max_iter_lanczos)
      END SELECT

      CALL dbcsr_release(matrix_bs)

      CALL timestop(handle)

   END SUBROUTINE compute_matrix_preconditioner

! **************************************************************************************************
!> \brief apply a preconditioner either
!>        forward (precondition)            inv(sqrt(bs)) * A * inv(sqrt(bs))
!>        backward (restore to old form)        sqrt(bs)  * A * sqrt(bs)
!>        The multiplications are done without filtering.
!> \param matrix the matrix A; overwritten in place with the result
!> \param direction "forward" or "backward"; any other value aborts
!> \param matrix_bs_sqrt sqrt(bs), used for the backward direction
!> \param matrix_bs_sqrt_inv inv(sqrt(bs)), used for the forward direction
!> \par History
!>       2010.10 created [Joost VandeVondele]
!> \author Joost VandeVondele
! **************************************************************************************************
   SUBROUTINE apply_matrix_preconditioner(matrix, direction, matrix_bs_sqrt, matrix_bs_sqrt_inv)

      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix
      CHARACTER(LEN=*), INTENT(IN)                       :: direction
      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_bs_sqrt, matrix_bs_sqrt_inv

      CHARACTER(LEN=*), PARAMETER :: routineN = 'apply_matrix_preconditioner'

      INTEGER                                            :: handle
      TYPE(dbcsr_type)                                   :: matrix_tmp

      CALL timeset(routineN, handle)
      CALL dbcsr_create(matrix_tmp, template=matrix, matrix_type=dbcsr_type_no_symmetry)

      SELECT CASE (direction)
      CASE ("forward")
         ! A <- inv(sqrt(bs)) * (A * inv(sqrt(bs)))
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix_bs_sqrt_inv, &
                             0.0_dp, matrix_tmp)
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_bs_sqrt_inv, matrix_tmp, &
                             0.0_dp, matrix)
      CASE ("backward")
         ! A <- sqrt(bs) * (A * sqrt(bs))
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix_bs_sqrt, &
                             0.0_dp, matrix_tmp)
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_bs_sqrt, matrix_tmp, &
                             0.0_dp, matrix)
      CASE DEFAULT
         CPABORT("Unknown preconditioner direction '"//TRIM(direction)//"', expected 'forward' or 'backward'")
      END SELECT

      CALL dbcsr_release(matrix_tmp)

      CALL timestop(handle)

   END SUBROUTINE apply_matrix_preconditioner

! **************************************************************************************************
!> \brief compute the density matrix with a trace that is close to nelectron.
!>        take a mu as input, and improve by bisection as needed.
!>        If sign_method is ls_scf_sign_submatrix with a direct_muadj(_lowmem) submatrix method,
!>        mu is adjusted inside density_matrix_sign_internal_mu. This path requires
!>        matrix_s_sqrt_inv, and fixed_mu, sign_order, matrix_s_inv and sign_symmetric are not used.
!>        Otherwise density_matrix_sign_fixed_mu is evaluated for at most max_iter values of mu.
!>        Starting from the input mu, mu is shifted by an increment that doubles every step
!>        (starting at initial_increment) until both a lower bound (trace < nelectron) and an
!>        upper bound (trace >= nelectron) are known; then the bracket is bisected.
!>        The search stops when:
!>        - |trace - nelectron| < trace_tolerance,
!>        - fixed_mu is set (after the first evaluation), or
!>        - the bracket is narrower than threshold.
!>        NOTE(review): when stopping on the bracket width, or after the last iteration, mu has
!>        already been updated past the value used for the final matrix_p evaluation.
!> \param matrix_p on output, the density matrix
!> \param mu chemical potential; initial guess on input, last value on output
!> \param fixed_mu if set, evaluate once at the input mu without searching
!> \param sign_method ls_scf_sign_ns, ls_scf_sign_proot or ls_scf_sign_submatrix
!> \param sign_order order passed to the matrix-sign routine
!> \param matrix_ks KS matrix H
!> \param matrix_s overlap matrix
!> \param matrix_s_inv inv(S), used by the non-symmetric sign variant
!> \param nelectron target trace
!> \param threshold filter threshold, also the minimal bisection bracket width
!> \param sign_symmetric use the symmetric (inv(sqrt(S))-based) variant; default .FALSE.
!> \param submatrix_sign_method default ls_scf_submatrix_sign_ns
!> \param matrix_s_sqrt_inv inv(sqrt(S)); required for sign_symmetric and for internal mu adjustment
!> \par History
!>       2010.10 created [Joost VandeVondele]
!>       2020.07 support for methods with internal mu adjustment [Michael Lass]
!> \author Joost VandeVondele
! **************************************************************************************************
   SUBROUTINE density_matrix_sign(matrix_p, mu, fixed_mu, sign_method, sign_order, matrix_ks, &
                                  matrix_s, matrix_s_inv, nelectron, threshold, sign_symmetric, submatrix_sign_method, &
                                  matrix_s_sqrt_inv)

      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_p
      REAL(KIND=dp), INTENT(INOUT)                       :: mu
      LOGICAL                                            :: fixed_mu
      INTEGER                                            :: sign_method, sign_order
      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_ks, matrix_s, matrix_s_inv
      INTEGER, INTENT(IN)                                :: nelectron
      REAL(KIND=dp), INTENT(IN)                          :: threshold
      LOGICAL, OPTIONAL                                  :: sign_symmetric
      INTEGER, OPTIONAL                                  :: submatrix_sign_method
      TYPE(dbcsr_type), INTENT(IN), OPTIONAL             :: matrix_s_sqrt_inv

      CHARACTER(LEN=*), PARAMETER :: routineN = 'density_matrix_sign'
      INTEGER, PARAMETER                                 :: max_iter = 30
      REAL(KIND=dp), PARAMETER                           :: initial_increment = 0.01_dp, &
                                                            trace_tolerance = 0.5_dp

      INTEGER                                            :: handle, iter, unit_nr, &
                                                            used_submatrix_sign_method
      LOGICAL                                            :: do_sign_symmetric, has_mu_high, &
                                                            has_mu_low, internal_mu_adjust
      REAL(KIND=dp)                                      :: increment, mu_high, mu_low, trace
      TYPE(cp_logger_type), POINTER                      :: logger

      CALL timeset(routineN, handle)

      logger => cp_get_default_logger()
      IF (logger%para_env%is_source()) THEN
         unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
      ELSE
         unit_nr = -1
      END IF

      do_sign_symmetric = .FALSE.
      IF (PRESENT(sign_symmetric)) do_sign_symmetric = sign_symmetric

      used_submatrix_sign_method = ls_scf_submatrix_sign_ns
      IF (PRESENT(submatrix_sign_method)) used_submatrix_sign_method = submatrix_sign_method

      internal_mu_adjust = ((sign_method .EQ. ls_scf_sign_submatrix) .AND. &
                            (used_submatrix_sign_method .EQ. ls_scf_submatrix_sign_direct_muadj .OR. &
                             used_submatrix_sign_method .EQ. ls_scf_submatrix_sign_direct_muadj_lowmem))

      IF (internal_mu_adjust) THEN
         ! density_matrix_sign_internal_mu declares matrix_s_sqrt_inv as a non-optional dummy;
         ! forwarding an absent optional argument to it would be invalid
         IF (.NOT. PRESENT(matrix_s_sqrt_inv)) &
            CPABORT("Argument matrix_s_sqrt_inv required for sign methods with internal mu adjustment")
         CALL density_matrix_sign_internal_mu(matrix_p, trace, mu, sign_method, &
                                              matrix_ks, matrix_s, threshold, &
                                              used_submatrix_sign_method, &
                                              nelectron, matrix_s_sqrt_inv)
      ELSE
         increment = initial_increment

         has_mu_low = .FALSE.
         has_mu_high = .FALSE.

         ! bisect if both bounds are known, otherwise find the bounds with a linear search
         DO iter = 1, max_iter
            IF (has_mu_low .AND. has_mu_high) THEN
               mu = (mu_low + mu_high)/2
               IF (ABS(mu_high - mu_low) < threshold) EXIT
            END IF

            CALL density_matrix_sign_fixed_mu(matrix_p, trace, mu, sign_method, sign_order, &
                                              matrix_ks, matrix_s, matrix_s_inv, threshold, &
                                              do_sign_symmetric, used_submatrix_sign_method, &
                                              matrix_s_sqrt_inv)
            IF (unit_nr > 0) WRITE (unit_nr, '(T2,A,I2,1X,F13.9,1X,F15.9)') &
               "Density matrix:  iter, mu, trace error: ", iter, mu, trace - nelectron

            ! OK, we can skip early if we are as close as possible to the exact result
            ! smaller differences should be considered 'noise'
            IF (ABS(trace - nelectron) < trace_tolerance .OR. fixed_mu) EXIT

            IF (trace < nelectron) THEN
               mu_low = mu
               mu = mu + increment
               has_mu_low = .TRUE.
               increment = increment*2
            ELSE
               mu_high = mu
               mu = mu - increment
               has_mu_high = .TRUE.
               increment = increment*2
            END IF
         END DO

      END IF

      CALL timestop(handle)

   END SUBROUTINE density_matrix_sign

! **************************************************************************************************
!> \brief for a fixed mu, compute the corresponding density matrix and its trace
!>        Non-symmetric variant: A = inv(S)*H - mu*I, PS = 0.5*(I - sign(A)), P = PS*inv(S)
!>        Symmetric variant:     A = inv(sqrt(S))*H*inv(sqrt(S)) - mu*I, X = 0.5*(I - sign(A)),
!>                               P = inv(sqrt(S))*X*inv(sqrt(S))
!>        The Frobenius norm of PS*PS - PS (resp. X*X - X) is printed on the source process
!>        as a measure of the deviation from idempotency.
!> \param matrix_p on output, the density matrix P
!> \param trace on output, the trace of PS (resp. X)
!> \param mu chemical potential; only read by this routine
!> \param sign_method ls_scf_sign_ns, ls_scf_sign_proot or ls_scf_sign_submatrix; others abort
!> \param sign_order order passed to the matrix-sign routine
!> \param matrix_ks KS matrix H
!> \param matrix_s overlap matrix, used as template for all work matrices
!> \param matrix_s_inv inv(S); only used by the non-symmetric variant
!> \param threshold filter threshold for all products and the sign iteration
!> \param sign_symmetric select the symmetric variant (requires matrix_s_sqrt_inv)
!> \param submatrix_sign_method passed to matrix_sign_submatrix
!> \param matrix_s_sqrt_inv inv(sqrt(S)); required if sign_symmetric is set
!> \par History
!>       2010.10 created [Joost VandeVondele]
!> \author Joost VandeVondele
! **************************************************************************************************
   SUBROUTINE density_matrix_sign_fixed_mu(matrix_p, trace, mu, sign_method, sign_order, matrix_ks, &
                                           matrix_s, matrix_s_inv, threshold, sign_symmetric, submatrix_sign_method, &
                                           matrix_s_sqrt_inv)

      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_p
      REAL(KIND=dp), INTENT(OUT)                         :: trace
      REAL(KIND=dp), INTENT(INOUT)                       :: mu
      INTEGER                                            :: sign_method, sign_order
      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_ks, matrix_s, matrix_s_inv
      REAL(KIND=dp), INTENT(IN)                          :: threshold
      LOGICAL                                            :: sign_symmetric
      INTEGER                                            :: submatrix_sign_method
      TYPE(dbcsr_type), INTENT(IN), OPTIONAL             :: matrix_s_sqrt_inv

      CHARACTER(LEN=*), PARAMETER :: routineN = 'density_matrix_sign_fixed_mu'

      INTEGER                                            :: handle, unit_nr
      REAL(KIND=dp)                                      :: frob_matrix
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dbcsr_type)                                   :: matrix_p_ud, matrix_shifted, matrix_sign, &
                                                            matrix_ssqrtinv_ks, matrix_tmp

      CALL timeset(routineN, handle)

      logger => cp_get_default_logger()
      IF (logger%para_env%is_source()) THEN
         unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
      ELSE
         unit_nr = -1
      END IF

      CALL dbcsr_create(matrix_sign, template=matrix_s, matrix_type=dbcsr_type_no_symmetry)

      IF (sign_symmetric) THEN
         IF (.NOT. PRESENT(matrix_s_sqrt_inv)) &
            CPABORT("Argument matrix_s_sqrt_inv required if sign_symmetric is set")
      END IF

      ! build the matrix whose sign is needed: A = inv(sqrt(S))*H*inv(sqrt(S)) or A = inv(S)*H
      CALL dbcsr_create(matrix_shifted, template=matrix_s, matrix_type=dbcsr_type_no_symmetry)
      IF (sign_symmetric) THEN
         CALL dbcsr_create(matrix_ssqrtinv_ks, template=matrix_s, matrix_type=dbcsr_type_no_symmetry)
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_sqrt_inv, matrix_ks, &
                             0.0_dp, matrix_ssqrtinv_ks, filter_eps=threshold)
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_ssqrtinv_ks, matrix_s_sqrt_inv, &
                             0.0_dp, matrix_shifted, filter_eps=threshold)
         CALL dbcsr_release(matrix_ssqrtinv_ks)
      ELSE
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_inv, matrix_ks, &
                             0.0_dp, matrix_shifted, filter_eps=threshold)
      END IF
      ! shift by the chemical potential: A <- A - mu*I
      CALL dbcsr_add_on_diag(matrix_shifted, -mu)

      ! compute sign(A) with the requested method (shared by both variants)
      SELECT CASE (sign_method)
      CASE (ls_scf_sign_ns)
         CALL matrix_sign_Newton_Schulz(matrix_sign, matrix_shifted, threshold, sign_order)
      CASE (ls_scf_sign_proot)
         CALL matrix_sign_proot(matrix_sign, matrix_shifted, threshold, sign_order)
      CASE (ls_scf_sign_submatrix)
         CALL matrix_sign_submatrix(matrix_sign, matrix_shifted, threshold, sign_order, submatrix_sign_method)
      CASE DEFAULT
         CPABORT("Unknown sign method.")
      END SELECT
      CALL dbcsr_release(matrix_shifted)

      ! now construct PS (resp. X) = 0.5*(I - sign(A))
      CALL dbcsr_create(matrix_p_ud, template=matrix_s, matrix_type=dbcsr_type_no_symmetry)
      CALL dbcsr_copy(matrix_p_ud, matrix_sign)
      CALL dbcsr_scale(matrix_p_ud, -0.5_dp)
      CALL dbcsr_add_on_diag(matrix_p_ud, 0.5_dp)
      CALL dbcsr_release(matrix_sign)

      ! we now have PS (resp. X), lets get its trace
      CALL dbcsr_trace(matrix_p_ud, trace)

      ! we can also check it is idempotent: PS*PS - PS should vanish
      CALL dbcsr_create(matrix_tmp, template=matrix_s, matrix_type=dbcsr_type_no_symmetry)
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p_ud, matrix_p_ud, &
                          0.0_dp, matrix_tmp, filter_eps=threshold)
      CALL dbcsr_add(matrix_tmp, matrix_p_ud, 1.0_dp, -1.0_dp)
      frob_matrix = dbcsr_frobenius_norm(matrix_tmp)
      IF (unit_nr > 0) WRITE (unit_nr, '(T2,A,F20.12)') "Deviation from idempotency: ", frob_matrix

      IF (sign_symmetric) THEN
         ! P = inv(sqrt(S))*X*inv(sqrt(S))
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_sqrt_inv, matrix_p_ud, &
                             0.0_dp, matrix_tmp, filter_eps=threshold)
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_tmp, matrix_s_sqrt_inv, &
                             0.0_dp, matrix_p, filter_eps=threshold)
      ELSE
         ! P = PS*inv(S)
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p_ud, matrix_s_inv, &
                             0.0_dp, matrix_p, filter_eps=threshold)
      END IF
      CALL dbcsr_release(matrix_p_ud)
      CALL dbcsr_release(matrix_tmp)

      CALL timestop(handle)

   END SUBROUTINE density_matrix_sign_fixed_mu

! **************************************************************************************************
!> \brief compute the corresponding density matrix and its trace, using methods with internal mu adjustment
!> \param matrix_p ...
!> \param trace ...
!> \param mu ...
!> \param sign_method ...
!> \param matrix_ks ...
!> \param matrix_s ...
!> \param threshold ...
!> \param submatrix_sign_method ...
!> \param nelectron ...
!> \param matrix_s_sqrt_inv ...
!> \par History
!>       2020.07 created, based on density_matrix_sign_fixed_mu [Michael Lass]
!> \author Michael Lass
! **************************************************************************************************
   SUBROUTINE density_matrix_sign_internal_mu(matrix_p, trace, mu, sign_method, matrix_ks, &
                                              matrix_s, threshold, submatrix_sign_method, &
                                              nelectron, matrix_s_sqrt_inv)

      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_p
      REAL(KIND=dp), INTENT(OUT)                         :: trace
      REAL(KIND=dp), INTENT(INOUT)                       :: mu
      INTEGER                                            :: sign_method
      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_ks, matrix_s
      REAL(KIND=dp), INTENT(IN)                          :: threshold
      INTEGER                                            :: submatrix_sign_method
      INTEGER, INTENT(IN)                                :: nelectron
      TYPE(dbcsr_type), INTENT(IN)                       :: matrix_s_sqrt_inv

      CHARACTER(LEN=*), PARAMETER :: routineN = 'density_matrix_sign_internal_mu'

      INTEGER                                            :: handle, unit_nr
      REAL(KIND=dp)                                      :: frob_matrix
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dbcsr_type)                                   :: matrix_p_ud, matrix_sign, &
                                                            matrix_ssqrtinv_ks_ssqrtinv, &
                                                            matrix_ssqrtinv_ks_ssqrtinv2, &
                                                            matrix_tmp

      CALL timeset(routineN, handle)

      logger => cp_get_default_logger()
      IF (logger%para_env%is_source()) THEN
         unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
      ELSE
         unit_nr = -1
      END IF

      CALL dbcsr_create(matrix_sign, template=matrix_s, matrix_type=dbcsr_type_no_symmetry)

      CALL dbcsr_create(matrix_ssqrtinv_ks_ssqrtinv, template=matrix_s, matrix_type=dbcsr_type_no_symmetry)
      CALL dbcsr_create(matrix_ssqrtinv_ks_ssqrtinv2, template=matrix_s, matrix_type=dbcsr_type_no_symmetry)
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_sqrt_inv, matrix_ks, &
                          0.0_dp, matrix_ssqrtinv_ks_ssqrtinv2, filter_eps=threshold)
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_ssqrtinv_ks_ssqrtinv2, matrix_s_sqrt_inv, &
                          0.0_dp, matrix_ssqrtinv_ks_ssqrtinv, filter_eps=threshold)
      CALL dbcsr_add_on_diag(matrix_ssqrtinv_ks_ssqrtinv, -mu)

      SELECT CASE (sign_method)
      CASE (ls_scf_sign_submatrix)
         SELECT CASE (submatrix_sign_method)
         CASE (ls_scf_submatrix_sign_direct_muadj, ls_scf_submatrix_sign_direct_muadj_lowmem)
            CALL matrix_sign_submatrix_mu_adjust(matrix_sign, matrix_ssqrtinv_ks_ssqrtinv, mu, nelectron, threshold, &
                                                 submatrix_sign_method)
         CASE DEFAULT
            CPABORT("density_matrix_sign_internal_mu called with invalid submatrix sign method")
         END SELECT
      CASE DEFAULT
         CPABORT("density_matrix_sign_internal_mu called with invalid sign method.")
      END SELECT
      CALL dbcsr_release(matrix_ssqrtinv_ks_ssqrtinv)
      CALL dbcsr_release(matrix_ssqrtinv_ks_ssqrtinv2)

      ! now construct the density matrix PS=0.5*(I-sign(inv(S)H-I*mu))
      CALL dbcsr_create(matrix_p_ud, template=matrix_s, matrix_type=dbcsr_type_no_symmetry)
      CALL dbcsr_copy(matrix_p_ud, matrix_sign)
      CALL dbcsr_scale(matrix_p_ud, -0.5_dp)
      CALL dbcsr_add_on_diag(matrix_p_ud, 0.5_dp)
      CALL dbcsr_release(matrix_sign)

      ! we now have PS, lets get its trace
      CALL dbcsr_trace(matrix_p_ud, trace)

      ! we can also check it is idempotent PS*PS=PS
      CALL dbcsr_create(matrix_tmp, template=matrix_s, matrix_type=dbcsr_type_no_symmetry)
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p_ud, matrix_p_ud, &
                          0.0_dp, matrix_tmp, filter_eps=threshold)
      CALL dbcsr_add(matrix_tmp, matrix_p_ud, 1.0_dp, -1.0_dp)
      frob_matrix = dbcsr_frobenius_norm(matrix_tmp)
      IF (unit_nr > 0) WRITE (unit_nr, '(T2,A,F20.12)') "Deviation from idempotency: ", frob_matrix

      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_sqrt_inv, matrix_p_ud, &
                          0.0_dp, matrix_tmp, filter_eps=threshold)
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_tmp, matrix_s_sqrt_inv, &
                          0.0_dp, matrix_p, filter_eps=threshold)
      CALL dbcsr_release(matrix_p_ud)
      CALL dbcsr_release(matrix_tmp)

      CALL timestop(handle)

   END SUBROUTINE density_matrix_sign_internal_mu

! **************************************************************************************************
!> \brief compute the density matrix using a trace-resetting algorithm (TRS4)
!> \param matrix_p on output the density matrix P = S^-1/2 * X * S^-1/2
!> \param matrix_ks Kohn-Sham matrix
!> \param matrix_s_sqrt_inv S^-1/2, used to go to and from the orthonormal basis
!> \param nelectron number of electrons the trace of X is driven towards
!> \param threshold filter threshold; also sets the idempotency convergence criterion
!> \param e_homo on input the HOMO estimate used by the dynamic threshold;
!>        on output (dynamic threshold only) the recomputed HOMO
!> \param e_lumo as e_homo, for the LUMO
!> \param e_mu on output the chemical potential obtained from the stored TRS4 polynomial
!>        (left unchanged when nelectron == 0)
!> \param dynamic_threshold if .TRUE., adapt the filter threshold during the iterations
!> \param matrix_ks_deviation required with dynamic_threshold: on input the previous mixed
!>        (orthogonalized) KS matrix, on output the current one
!> \param max_iter_lanczos maximum number of Arnoldi iterations for eigenvalue estimates
!> \param eps_lanczos convergence threshold for the Arnoldi eigenvalue estimates
!> \param converged .TRUE. if the purification converged; also .TRUE. for nelectron == 0,
!>        where P = 0 is exact
!> \par History
!>       2012.06 created [Florian Thoele]
!> \author Florian Thoele
! **************************************************************************************************
   SUBROUTINE density_matrix_trs4(matrix_p, matrix_ks, matrix_s_sqrt_inv, &
                                  nelectron, threshold, e_homo, e_lumo, e_mu, &
                                  dynamic_threshold, matrix_ks_deviation, &
                                  max_iter_lanczos, eps_lanczos, converged)

      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_p
      TYPE(dbcsr_type), INTENT(IN)                       :: matrix_ks, matrix_s_sqrt_inv
      INTEGER, INTENT(IN)                                :: nelectron
      REAL(KIND=dp), INTENT(IN)                          :: threshold
      REAL(KIND=dp), INTENT(INOUT)                       :: e_homo, e_lumo, e_mu
      LOGICAL, INTENT(IN), OPTIONAL                      :: dynamic_threshold
      TYPE(dbcsr_type), INTENT(INOUT), OPTIONAL          :: matrix_ks_deviation
      INTEGER, INTENT(IN)                                :: max_iter_lanczos
      REAL(KIND=dp), INTENT(IN)                          :: eps_lanczos
      LOGICAL, INTENT(OUT), OPTIONAL                     :: converged

      CHARACTER(LEN=*), PARAMETER :: routineN = 'density_matrix_trs4'
      INTEGER, PARAMETER                                 :: max_iter = 100
      REAL(KIND=dp), PARAMETER                           :: gamma_max = 6.0_dp, gamma_min = 0.0_dp

      INTEGER                                            :: branch, estimated_steps, handle, i, j, &
                                                            unit_nr
      INTEGER(kind=int_8)                                :: flop1, flop2
      LOGICAL                                            :: arnoldi_converged, do_dyn_threshold
      REAL(KIND=dp) :: current_threshold, delta_n, eps_max, eps_min, est_threshold, frob_id, &
         frob_x, gam, homo, lumo, max_eig, max_threshold, maxdev, maxev, min_eig, minev, mmin, mu, &
         mu_a, mu_b, mu_c, mu_fa, mu_fc, occ_matrix, scaled_homo_bound, scaled_lumo_bound, t1, t2, &
         trace_fx, trace_gx, xi
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: gamma_values
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dbcsr_type)                                   :: matrix_k0, matrix_x, matrix_x_nosym, &
                                                            matrix_xidsq, matrix_xsq, tmp_gx

      IF (nelectron == 0) THEN
         ! Without electrons the density matrix is exactly zero. Report convergence so that
         ! callers never read an undefined INTENT(OUT) flag.
         CALL dbcsr_set(matrix_p, 0.0_dp)
         IF (PRESENT(converged)) converged = .TRUE.
         RETURN
      END IF

      CALL timeset(routineN, handle)

      logger => cp_get_default_logger()
      IF (logger%para_env%is_source()) THEN
         unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
      ELSE
         unit_nr = -1
      END IF

      do_dyn_threshold = .FALSE.
      IF (PRESENT(dynamic_threshold)) do_dyn_threshold = dynamic_threshold

      IF (PRESENT(converged)) converged = .FALSE.

      ! init X = (eps_n*I - H)/(eps_n - eps_0)  ... H* = S^-1/2*H*S^-1/2
      CALL dbcsr_create(matrix_x, template=matrix_ks, matrix_type="S")

      ! at some points the non-symmetric version of x is required
      CALL dbcsr_create(matrix_x_nosym, template=matrix_ks, matrix_type=dbcsr_type_no_symmetry)

      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_sqrt_inv, matrix_ks, &
                          0.0_dp, matrix_x_nosym, filter_eps=threshold)
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_x_nosym, matrix_s_sqrt_inv, &
                          0.0_dp, matrix_x, filter_eps=threshold)
      CALL dbcsr_desymmetrize(matrix_x, matrix_x_nosym)

      ! keep the unscaled orthogonalized KS matrix for the final HOMO/LUMO estimate
      CALL dbcsr_create(matrix_k0, template=matrix_ks, matrix_type=dbcsr_type_no_symmetry)
      CALL dbcsr_copy(matrix_k0, matrix_x_nosym)

      ! compute the deviation in the mixed matrix, as seen in the ortho basis
      IF (do_dyn_threshold) THEN
         CPASSERT(PRESENT(matrix_ks_deviation))
         ! deviation = current - previous mixed matrix
         CALL dbcsr_add(matrix_ks_deviation, matrix_x_nosym, -1.0_dp, 1.0_dp)
         CALL arnoldi_extremal(matrix_ks_deviation, maxev, minev, max_iter=max_iter_lanczos, threshold=eps_lanczos, &
                               converged=arnoldi_converged)
         maxdev = MAX(ABS(maxev), ABS(minev))
         IF (unit_nr > 0) THEN
            WRITE (unit_nr, '(T6,A,1X,L12)') "Lanczos converged:      ", arnoldi_converged
            WRITE (unit_nr, '(T6,A,1X,F12.5)') "change in mixed matrix: ", maxdev
            WRITE (unit_nr, '(T6,A,1X,F12.5)') "HOMO upper bound:       ", e_homo + maxdev
            WRITE (unit_nr, '(T6,A,1X,F12.5)') "LUMO lower bound:       ", e_lumo - maxdev
            WRITE (unit_nr, '(T6,A,1X,L12)') "Predicts a gap ?        ", ((e_lumo - maxdev) - (e_homo + maxdev)) > 0
         END IF
         ! save the old mixed matrix
         CALL dbcsr_copy(matrix_ks_deviation, matrix_x_nosym)

      END IF

      ! get largest/smallest eigenvalues for scaling
      CALL arnoldi_extremal(matrix_x_nosym, max_eig, min_eig, max_iter=max_iter_lanczos, threshold=eps_lanczos, &
                            converged=arnoldi_converged)
      IF (unit_nr > 0) WRITE (unit_nr, '(T6,A,1X,2F12.5,1X,A,1X,L1)') "Est. extremal eigenvalues", &
         min_eig, max_eig, " converged: ", arnoldi_converged
      eps_max = max_eig
      eps_min = min_eig

      ! scale KS matrix
      IF (eps_max == eps_min) THEN
         CALL dbcsr_scale(matrix_x, 1.0_dp/eps_max)
      ELSE
         CALL dbcsr_add_on_diag(matrix_x, -eps_max)
         CALL dbcsr_scale(matrix_x, -1.0_dp/(eps_max - eps_min))
      END IF

      current_threshold = threshold
      IF (do_dyn_threshold) THEN
         ! scale bounds for HOMO/LUMO
         scaled_homo_bound = (eps_max - (e_homo + maxdev))/(eps_max - eps_min)
         scaled_lumo_bound = (eps_max - (e_lumo - maxdev))/(eps_max - eps_min)
      END IF

      CALL dbcsr_create(matrix_xsq, template=matrix_ks, matrix_type="S")

      CALL dbcsr_create(matrix_xidsq, template=matrix_ks, matrix_type="S")

      CALL dbcsr_create(tmp_gx, template=matrix_ks, matrix_type="S")

      ! the gamma of every iteration is kept to evaluate the polynomial for the chemical potential
      ALLOCATE (gamma_values(max_iter))

      DO i = 1, max_iter
         t1 = m_walltime()
         flop1 = 0; flop2 = 0

         ! get X*X
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_x, matrix_x, &
                             0.0_dp, matrix_xsq, &
                             filter_eps=current_threshold, flop=flop1)

         ! intermediate use matrix_xidsq to compute = X*X-X
         CALL dbcsr_copy(matrix_xidsq, matrix_x)
         CALL dbcsr_add(matrix_xidsq, matrix_xsq, -1.0_dp, 1.0_dp)
         frob_id = dbcsr_frobenius_norm(matrix_xidsq)
         frob_x = dbcsr_frobenius_norm(matrix_x)

         ! xidsq = (1-X)*(1-X)
         ! use (1-x)*(1-x) = 1 + x*x - 2*x
         CALL dbcsr_copy(matrix_xidsq, matrix_x)
         CALL dbcsr_add(matrix_xidsq, matrix_xsq, -2.0_dp, 1.0_dp)
         CALL dbcsr_add_on_diag(matrix_xidsq, 1.0_dp)

         ! tmp_gx = 4X-3X*X
         CALL dbcsr_copy(tmp_gx, matrix_x)
         CALL dbcsr_add(tmp_gx, matrix_xsq, 4.0_dp, -3.0_dp)

         ! get gamma
         ! Tr(F) = Tr(XX*tmp_gx) Tr(G) is equivalent
         CALL dbcsr_dot(matrix_xsq, matrix_xidsq, trace_gx)
         CALL dbcsr_dot(matrix_xsq, tmp_gx, trace_fx)

         ! if converged, and gam becomes noisy, fix it to 3, which results in a final McWeeny step.
         ! do this only if the electron count is reasonable.
         ! maybe tune if the current criterion is not good enough
         delta_n = nelectron - trace_fx
         ! condition: ABS(frob_id/frob_x) < SQRT(threshold) ...
         IF (((frob_id*frob_id) < (threshold*frob_x*frob_x)) .AND. (ABS(delta_n) < 0.5_dp)) THEN
            gam = 3.0_dp
         ELSE IF (ABS(delta_n) < 1e-14_dp) THEN
            gam = 0.0_dp ! rare case of perfect electron count
         ELSE
            ! make sure, we don't divide by zero, as soon as gam is outside the interval gam_min,gam_max, it doesn't matter
            gam = delta_n/MAX(trace_gx, ABS(delta_n)/100)
         END IF
         gamma_values(i) = gam

         IF (unit_nr > 0 .AND. .FALSE.) THEN
            WRITE (unit_nr, *) "trace_fx", trace_fx, "trace_gx", trace_gx, "gam", gam, &
               "frob_id", frob_id, "conv", ABS(frob_id/frob_x)
         END IF

         IF (do_dyn_threshold) THEN
            ! quantities used for dynamic thresholding, when the estimated gap is larger than zero
            xi = (scaled_homo_bound - scaled_lumo_bound)
            IF (xi > 0.0_dp) THEN
               mmin = 0.5_dp*(scaled_homo_bound + scaled_lumo_bound)
               max_threshold = ABS(1 - 2*mmin)*xi

               ! advance the bounds by the step about to be taken
               scaled_homo_bound = evaluate_trs4_polynomial(scaled_homo_bound, gamma_values(i:), 1)
               scaled_lumo_bound = evaluate_trs4_polynomial(scaled_lumo_bound, gamma_values(i:), 1)
               estimated_steps = estimate_steps(scaled_homo_bound, scaled_lumo_bound, threshold)

               est_threshold = (threshold/(estimated_steps + i + 1))*xi/(1 + threshold/(estimated_steps + i + 1))
               est_threshold = MIN(max_threshold, est_threshold)
               IF (i > 1) est_threshold = MAX(est_threshold, 0.1_dp*current_threshold)
               current_threshold = est_threshold
            ELSE
               current_threshold = threshold
            END IF
         END IF

         IF (gam > gamma_max) THEN
            ! Xn+1 = 2X-X*X
            CALL dbcsr_add(matrix_x, matrix_xsq, 2.0_dp, -1.0_dp)
            CALL dbcsr_filter(matrix_x, current_threshold)
            branch = 1
         ELSE IF (gam < gamma_min) THEN
            ! Xn+1 = X*X
            CALL dbcsr_copy(matrix_x, matrix_xsq)
            branch = 2
         ELSE
            ! Xn+1 = F(X) + gam*G(X)
            CALL dbcsr_add(tmp_gx, matrix_xidsq, 1.0_dp, gam)
            CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_xsq, tmp_gx, &
                                0.0_dp, matrix_x, &
                                flop=flop2, filter_eps=current_threshold)
            branch = 3
         END IF

         occ_matrix = dbcsr_get_occupation(matrix_x)
         t2 = m_walltime()
         IF (unit_nr > 0) THEN
            WRITE (unit_nr, &
                   '(T6,A,I3,1X,F10.8,E12.3,F12.3,F13.3,E12.3)') "TRS4 it ", &
               i, occ_matrix, ABS(trace_gx), t2 - t1, &
               (flop1 + flop2)/(1.0E6_dp*MAX(t2 - t1, 0.001_dp)), current_threshold
            CALL m_flush(unit_nr)
         END IF

         IF (abnormal_value(trace_gx)) &
            CPABORT("trace_gx is an abnormal value (NaN/Inf).")

         ! a branch of 1 or 2 appears to lead to a less accurate electron number count and premature exit
         ! if it turns out this does not exit because we get stuck in branch 1/2 for a reason we need to refine further
         ! condition: ABS(frob_id/frob_x) < SQRT(threshold) ...
         IF ((frob_id*frob_id) < (threshold*frob_x*frob_x) .AND. branch == 3 .AND. (ABS(delta_n) < 0.5_dp)) THEN
            IF (PRESENT(converged)) converged = .TRUE.
            EXIT
         END IF

      END DO

      occ_matrix = dbcsr_get_occupation(matrix_x)
      IF (unit_nr > 0) WRITE (unit_nr, '(T6,A,I3,1X,F10.8,E12.3)') 'Final TRS4 iteration  ', i, occ_matrix, ABS(trace_gx)

      ! free some memory
      CALL dbcsr_release(tmp_gx)
      CALL dbcsr_release(matrix_xsq)
      CALL dbcsr_release(matrix_xidsq)

      ! output to matrix_p, P = inv(S)^0.5 X inv(S)^0.5
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_x, matrix_s_sqrt_inv, &
                          0.0_dp, matrix_x_nosym, filter_eps=threshold)
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_sqrt_inv, matrix_x_nosym, &
                          0.0_dp, matrix_p, filter_eps=threshold)

      ! calculate the chemical potential by doing a bisection of fk(x0)-0.5, where fk is evaluated using the stored values for gamma
      ! E. Rubensson et al., Chem Phys Lett 432, 2006, 591-594
      mu_a = 0.0_dp
      mu_b = 1.0_dp
      mu_fa = evaluate_trs4_polynomial(mu_a, gamma_values, i - 1) - 0.5_dp
      DO j = 1, 40
         mu_c = 0.5_dp*(mu_a + mu_b)
         mu_fc = evaluate_trs4_polynomial(mu_c, gamma_values, i - 1) - 0.5_dp ! i-1 because in the last iteration, only convergence is checked
         IF (ABS(mu_fc) < 1.0E-6_dp .OR. (mu_b - mu_a)/2 < 1.0E-6_dp) EXIT !TODO: define threshold values

         IF (mu_fc*mu_fa > 0) THEN
            mu_a = mu_c
            mu_fa = mu_fc
         ELSE
            mu_b = mu_c
         END IF
      END DO
      ! undo the spectral scaling X = (eps_max*I - H)/(eps_max - eps_min)
      mu = (eps_min - eps_max)*mu_c + eps_max
      DEALLOCATE (gamma_values)
      IF (unit_nr > 0) THEN
         WRITE (unit_nr, '(T6,A,1X,F12.5)') 'Chemical potential (mu): ', mu
      END IF
      e_mu = mu

      IF (do_dyn_threshold) THEN
         CALL dbcsr_desymmetrize(matrix_x, matrix_x_nosym)
         CALL compute_homo_lumo(matrix_k0, matrix_x_nosym, eps_min, eps_max, &
                                threshold, max_iter_lanczos, eps_lanczos, homo, lumo, unit_nr)
         e_homo = homo
         e_lumo = lumo
      END IF

      CALL dbcsr_release(matrix_x)
      CALL dbcsr_release(matrix_x_nosym)
      CALL dbcsr_release(matrix_k0)
      CALL timestop(handle)

   END SUBROUTINE density_matrix_trs4

! **************************************************************************************************
!> \brief compute the density matrix using a non monotonic trace conserving
!>  algorithm based on SIAM DOI. 10.1137/130911585.
!>       2014.04 created [Jonathan Mullin]
!> \param matrix_p on output the density matrix P = S^-1/2 * X * S^-1/2
!> \param matrix_ks Kohn-Sham matrix
!> \param matrix_s_sqrt_inv S^-1/2, used to go to and from the orthonormal basis
!> \param nelectron number of electrons
!> \param threshold filter threshold; also the convergence criterion on ||X - X*X||_F
!> \param e_homo on input a HOMO estimate used for the non-monotonic bounds
!>        (a value of exactly 0 means "no estimate"); on output the outer HOMO bound
!> \param e_lumo as e_homo, for the LUMO
!> \param non_monotonic if .TRUE., use the non-monotonic bounds; monotonic otherwise/if absent
!> \param eps_lanczos convergence threshold for the Arnoldi eigenvalue estimates
!> \param max_iter_lanczos maximum number of Arnoldi iterations
!> \author Jonathan Mullin
! **************************************************************************************************
   SUBROUTINE density_matrix_tc2(matrix_p, matrix_ks, matrix_s_sqrt_inv, &
                                 nelectron, threshold, e_homo, e_lumo, &
                                 non_monotonic, eps_lanczos, max_iter_lanczos)

      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_p
      TYPE(dbcsr_type), INTENT(IN)                       :: matrix_ks, matrix_s_sqrt_inv
      INTEGER, INTENT(IN)                                :: nelectron
      REAL(KIND=dp), INTENT(IN)                          :: threshold
      REAL(KIND=dp), INTENT(INOUT)                       :: e_homo, e_lumo
      LOGICAL, INTENT(IN), OPTIONAL                      :: non_monotonic
      REAL(KIND=dp), INTENT(IN)                          :: eps_lanczos
      INTEGER, INTENT(IN)                                :: max_iter_lanczos

      CHARACTER(LEN=*), PARAMETER :: routineN = 'density_matrix_tc2'
      INTEGER, PARAMETER                                 :: max_iter = 100

      INTEGER                                            :: handle, i, j, k, unit_nr
      INTEGER(kind=int_8)                                :: flop1, flop2
      LOGICAL                                            :: converged, do_non_monotonic
      REAL(KIND=dp)                                      :: beta, betaB, eps_max, eps_min, gama, &
                                                            max_eig, min_eig, occ_matrix, t1, t2, &
                                                            trace_fx, trace_gx
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: alpha, lambda, nu, poly, wu, X, Y
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dbcsr_type)                                   :: matrix_tmp, matrix_x, matrix_xsq

      CALL timeset(routineN, handle)

      logger => cp_get_default_logger()
      IF (logger%para_env%is_source()) THEN
         unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
      ELSE
         unit_nr = -1
      END IF

      do_non_monotonic = .FALSE.
      IF (PRESENT(non_monotonic)) do_non_monotonic = non_monotonic

      ! init X = (eps_n*I - H)/(eps_n - eps_0)  ... H* = S^-1/2*H*S^-1/2
      CALL dbcsr_create(matrix_x, template=matrix_ks, matrix_type=dbcsr_type_no_symmetry)
      CALL dbcsr_create(matrix_xsq, template=matrix_ks, matrix_type=dbcsr_type_no_symmetry)

      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_sqrt_inv, matrix_ks, &
                          0.0_dp, matrix_xsq, filter_eps=threshold)
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_xsq, matrix_s_sqrt_inv, &
                          0.0_dp, matrix_x, filter_eps=threshold)

      IF (unit_nr > 0) THEN
         WRITE (unit_nr, '(T6,A,1X,F12.5)') "HOMO upper bound:       ", e_homo
         WRITE (unit_nr, '(T6,A,1X,F12.5)') "LUMO lower bound:       ", e_lumo
         WRITE (unit_nr, '(T6,A,1X,L12)') "Predicts a gap ?        ", ((e_lumo) - (e_homo)) > 0
      END IF

      ! get largest/smallest eigenvalues for scaling
      CALL arnoldi_extremal(matrix_x, max_eig, min_eig, max_iter=max_iter_lanczos, threshold=eps_lanczos, &
                            converged=converged)
      IF (unit_nr > 0) WRITE (unit_nr, '(T6,A,1X,2F12.5,1X,A,1X,L1)') "Est. extremal eigenvalues", &
         min_eig, max_eig, " converged: ", converged

      eps_max = max_eig
      eps_min = min_eig

      ! scale KS matrix
      CALL dbcsr_scale(matrix_x, -1.0_dp)
      CALL dbcsr_add_on_diag(matrix_x, eps_max)
      CALL dbcsr_scale(matrix_x, 1/(eps_max - eps_min))

      CALL dbcsr_copy(matrix_xsq, matrix_x)

      CALL dbcsr_create(matrix_tmp, template=matrix_ks, matrix_type=dbcsr_type_no_symmetry)

      ! per-iteration history: branch taken (poly), step parameter (alpha),
      ! ||X - X*X||_F (nu) and Tr(X - X*X) (wu); all are needed by Algorithm 3 below
      ALLOCATE (poly(max_iter))
      ALLOCATE (nu(max_iter))
      ALLOCATE (wu(max_iter))
      ALLOCATE (alpha(max_iter))
      ALLOCATE (X(4))
      ALLOCATE (Y(4))
      ALLOCATE (lambda(4))

! Controls over the non_monotonic bounds, First if low gap, bias slightly
      beta = (eps_max - ABS(e_lumo))/(eps_max - eps_min)
      betaB = (eps_max + ABS(e_homo))/(eps_max - eps_min)

      IF ((beta - betaB) < 0.005_dp) THEN
         beta = beta - 0.002_dp
         betaB = betaB + 0.002_dp
      END IF
! Check if input specifies to use monotonic bounds.
      IF (.NOT. do_non_monotonic) THEN
         beta = 0.0_dp
         betaB = 1.0_dp
      END IF
! initial SCF cycle has no reliable estimate of homo/lumo, force monotonic bounds.
      IF (e_homo == 0.0_dp) THEN
         beta = 0.0_dp
         BetaB = 1.0_dp
      END IF

      ! init to take true branch first
      trace_fx = nelectron
      trace_gx = 0

      DO i = 1, max_iter
         t1 = m_walltime()
         flop1 = 0; flop2 = 0

         IF (ABS(trace_fx - nelectron) <= ABS(trace_gx - nelectron)) THEN
! Xn+1 = (aX+ (1-a)I)^2
            poly(i) = 1.0_dp
            alpha(i) = 2.0_dp/(2.0_dp - beta)

            CALL dbcsr_scale(matrix_x, alpha(i))
            CALL dbcsr_add_on_diag(matrix_x, 1.0_dp - alpha(i))
            CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_x, matrix_x, &
                                0.0_dp, matrix_xsq, &
                                filter_eps=threshold, flop=flop1)

!save X for control variables
            CALL dbcsr_copy(matrix_tmp, matrix_x)

            CALL dbcsr_copy(matrix_x, matrix_xsq)

            beta = (1.0_dp - alpha(i)) + alpha(i)*beta
            beta = beta*beta
            betaB = (1.0_dp - alpha(i)) + alpha(i)*betaB
            betaB = betaB*betaB
         ELSE
! Xn+1 = 2aX-a^2*X^2
            poly(i) = 0.0_dp
            alpha(i) = 2.0_dp/(1.0_dp + betaB)

            CALL dbcsr_scale(matrix_x, alpha(i))
            CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_x, matrix_x, &
                                0.0_dp, matrix_xsq, &
                                filter_eps=threshold, flop=flop1)

!save X for control variables
            CALL dbcsr_copy(matrix_tmp, matrix_x)
!
            CALL dbcsr_add(matrix_x, matrix_xsq, 2.0_dp, -1.0_dp)

            beta = alpha(i)*beta
            beta = 2.0_dp*beta - beta*beta
            betaB = alpha(i)*betaB
            betaB = 2.0_dp*betaB - betaB*betaB

         END IF
         occ_matrix = dbcsr_get_occupation(matrix_x)
         t2 = m_walltime()
         IF (unit_nr > 0) THEN
            ! one edit descriptor per printed item (occupation, time, rate, threshold);
            ! the elapsed time is bounded away from zero as in the TRS4 report
            WRITE (unit_nr, &
                   '(T6,A,I3,1X,F10.8,F12.3,F13.3,E12.3)') "TC2 it ", &
               i, occ_matrix, t2 - t1, &
               (flop1 + flop2)/(1.0E6_dp*MAX(t2 - t1, 0.001_dp)), threshold
            CALL m_flush(unit_nr)
         END IF

! calculate control terms
         CALL dbcsr_trace(matrix_xsq, trace_fx)

! intermediate use matrix_xsq compute X- X*X , temporarily use trace_gx
         CALL dbcsr_add(matrix_xsq, matrix_tmp, -1.0_dp, 1.0_dp)
         CALL dbcsr_trace(matrix_xsq, trace_gx)
         nu(i) = dbcsr_frobenius_norm(matrix_xsq)
         wu(i) = trace_gx

! intermediate use matrix_xsq to compute = 2X - X*X
         CALL dbcsr_add(matrix_xsq, matrix_tmp, 1.0_dp, 1.0_dp)
         CALL dbcsr_trace(matrix_xsq, trace_gx)
! TC2 has quadratic convergence, using the frobeniums norm as an idempotency deviation test.
         IF (ABS(nu(i)) < (threshold)) EXIT
      END DO

      ! Without convergence the DO variable ends at max_iter + 1, one past the end of the
      ! history arrays; clamp it to the last iteration actually performed.
      i = MIN(i, max_iter)

      occ_matrix = dbcsr_get_occupation(matrix_x)
      IF (unit_nr > 0) WRITE (unit_nr, '(T6,A,I3,1X,1F10.8,1X,1F10.8)') 'Final TC2 iteration  ', i, occ_matrix, ABS(nu(i))

      ! output to matrix_p, P = inv(S)^0.5 X inv(S)^0.5
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_x, matrix_s_sqrt_inv, &
                          0.0_dp, matrix_tmp, filter_eps=threshold)
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s_sqrt_inv, matrix_tmp, &
                          0.0_dp, matrix_p, filter_eps=threshold)

      CALL dbcsr_release(matrix_xsq)
      CALL dbcsr_release(matrix_tmp)

      ! ALGO 3 from. SIAM DOI. 10.1137/130911585
      X(1) = 1.0_dp
      X(2) = 1.0_dp
      X(3) = 0.0_dp
      X(4) = 0.0_dp
      gama = 6.0_dp - 4.0_dp*(SQRT(2.0_dp))
      gama = gama - gama*gama
      ! walk the iteration history backwards while the idempotency error is small enough;
      ! stop at the first iteration so nu(0) is never read
      DO WHILE (i >= 1)
         IF (.NOT. (nu(i) < gama)) EXIT
         ! safeguard against negative root, is skipping correct?
         IF (wu(i) < 1.0e-14_dp) THEN
            i = i - 1
            CYCLE
         END IF
         IF ((1.0_dp - 4.0_dp*nu(i)*nu(i)/wu(i)) < 0.0_dp) THEN
            i = i - 1
            CYCLE
         END IF
         Y(1) = 0.5_dp*(1.0_dp - SQRT(1.0_dp - 4.0_dp*nu(i)*nu(i)/wu(i)))
         Y(2) = 0.5_dp*(1.0_dp - SQRT(1.0_dp - 4.0_dp*nu(i)))
         Y(3) = 0.5_dp*(1.0_dp + SQRT(1.0_dp - 4.0_dp*nu(i)))
         Y(4) = 0.5_dp*(1.0_dp + SQRT(1.0_dp - 4.0_dp*nu(i)*nu(i)/wu(i)))
         Y(:) = MIN(1.0_dp, MAX(0.0_dp, Y(:)))
         ! map the bounds back through the inverse of every polynomial step taken so far
         DO j = i, 1, -1
            IF (poly(j) == 1.0_dp) THEN
               DO k = 1, 4
                  Y(k) = SQRT(Y(k))
                  Y(k) = (Y(k) - 1.0_dp + alpha(j))/alpha(j)
               END DO ! end K
            ELSE
               DO k = 1, 4
                  Y(k) = 1.0_dp - SQRT(1.0_dp - Y(k))
                  Y(k) = Y(k)/alpha(j)
               END DO ! end K
            END IF ! end poly
         END DO ! end j
         X(1) = MIN(X(1), Y(1))
         X(2) = MIN(X(2), Y(2))
         X(3) = MAX(X(3), Y(3))
         X(4) = MAX(X(4), Y(4))
         i = i - 1
      END DO ! end i
!   lambda 1,2,3,4 are:: out lumo, in lumo, in homo, out homo
      DO k = 1, 4
         lambda(k) = eps_max - (eps_max - eps_min)*X(k)
      END DO ! end k
! END  ALGO 3 from. SIAM DOI. 10.1137/130911585
      e_homo = lambda(4)
      e_lumo = lambda(1)
      IF (unit_nr > 0) WRITE (unit_nr, '(T6,A,3E12.4)') "outer homo/lumo/gap", e_homo, e_lumo, (e_lumo - e_homo)
      IF (unit_nr > 0) WRITE (unit_nr, '(T6,A,3E12.4)') "inner homo/lumo/gap", lambda(3), lambda(2), (lambda(2) - lambda(3))

      DEALLOCATE (poly)
      DEALLOCATE (nu)
      DEALLOCATE (wu)
      DEALLOCATE (alpha)
      DEALLOCATE (X)
      DEALLOCATE (Y)
      DEALLOCATE (lambda)

      CALL dbcsr_release(matrix_x)
      CALL timestop(handle)

   END SUBROUTINE density_matrix_tc2

! **************************************************************************************************
!> \brief compute the homo and lumo given a KS matrix and a density matrix in the orthonormalized basis
!>        and the eps_min and eps_max, min and max eigenvalue of the ks matrix
!> \param matrix_k KS matrix in the orthonormalized basis
!> \param matrix_p density matrix in the orthonormalized basis
!> \param eps_min smallest eigenvalue (estimate) of matrix_k
!> \param eps_max largest eigenvalue (estimate) of matrix_k
!> \param threshold filter threshold for the matrix products
!> \param max_iter_lanczos maximum number of Arnoldi iterations
!> \param eps_lanczos convergence threshold for the Arnoldi eigenvalue estimates
!> \param homo on output: largest eigenvalue of P*(K - eps_min*I), shifted back by eps_min
!> \param lumo on output: eps_max minus the largest eigenvalue of -(I-P)*(K - eps_max*I)
!> \param unit_nr output unit; nothing is written if unit_nr <= 0
!> \par History
!>       2012.06 created [Florian Thoele]
!> \author Florian Thoele
! **************************************************************************************************
   SUBROUTINE compute_homo_lumo(matrix_k, matrix_p, eps_min, eps_max, threshold, max_iter_lanczos, eps_lanczos, homo, lumo, unit_nr)
      TYPE(dbcsr_type), INTENT(IN)                       :: matrix_k, matrix_p
      REAL(KIND=dp), INTENT(IN)                          :: eps_min, eps_max, threshold
      INTEGER, INTENT(IN)                                :: max_iter_lanczos
      REAL(KIND=dp), INTENT(IN)                          :: eps_lanczos
      REAL(KIND=dp), INTENT(OUT)                         :: homo, lumo
      INTEGER, INTENT(IN)                                :: unit_nr

      LOGICAL                                            :: converged
      REAL(KIND=dp)                                      :: max_eig, min_eig, shift1, shift2
      TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3

! temporary matrices used for HOMO/LUMO calculation

      CALL dbcsr_create(tmp1, template=matrix_k, matrix_type=dbcsr_type_no_symmetry)

      CALL dbcsr_create(tmp2, template=matrix_k, matrix_type=dbcsr_type_no_symmetry)

      CALL dbcsr_create(tmp3, template=matrix_k, matrix_type=dbcsr_type_no_symmetry)

      ! shifts that move the spectrum of K to be non-negative (shift1) or non-positive (shift2)
      shift1 = -eps_min
      shift2 = eps_max

      ! find largest ev of P*(K+shift*1), where shift is the neg. val. of the smallest ev of K
      CALL dbcsr_copy(tmp2, matrix_k)
      CALL dbcsr_add_on_diag(tmp2, shift1)
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p, tmp2, &
                          0.0_dp, tmp1, filter_eps=threshold)
      CALL arnoldi_extremal(tmp1, max_eig, min_eig, converged=converged, &
                            threshold=eps_lanczos, max_iter=max_iter_lanczos)
      homo = max_eig - shift1
      IF (unit_nr > 0) THEN
         WRITE (unit_nr, '(T6,A,1X,L12)') "Lanczos converged:      ", converged
      END IF

      ! -(1-P)*(K-shift*1) = (1-P)*(shift*1 - K), where shift is the largest ev of K
      CALL dbcsr_copy(tmp3, matrix_p)
      CALL dbcsr_scale(tmp3, -1.0_dp)
      CALL dbcsr_add_on_diag(tmp3, 1.0_dp) !tmp3 = 1-P
      CALL dbcsr_copy(tmp2, matrix_k)
      CALL dbcsr_add_on_diag(tmp2, -shift2)
      CALL dbcsr_multiply("N", "N", -1.0_dp, tmp3, tmp2, &
                          0.0_dp, tmp1, filter_eps=threshold)
      CALL arnoldi_extremal(tmp1, max_eig, min_eig, converged=converged, &
                            threshold=eps_lanczos, max_iter=max_iter_lanczos)
      lumo = -max_eig + shift2

      IF (unit_nr > 0) THEN
         WRITE (unit_nr, '(T6,A,1X,L12)') "Lanczos converged:      ", converged
         WRITE (unit_nr, '(T6,A,1X,3F12.5)') 'HOMO/LUMO/gap', homo, lumo, lumo - homo
      END IF
      CALL dbcsr_release(tmp1)
      CALL dbcsr_release(tmp2)
      CALL dbcsr_release(tmp3)

   END SUBROUTINE compute_homo_lumo

! **************************************************************************************************
!> \brief apply the first i TRS4 purification steps, as recorded in gamma_values, to a scalar
!> \param x starting value (an eigenvalue of the scaled KS matrix, in [0,1])
!> \param gamma_values gamma of each step: > 6 selects 2x-x^2, < 0 selects x^2,
!>        otherwise x^2*(4x-3x^2) + gamma*x^2*(1-x)^2
!> \param i number of steps to apply; 0 returns x unchanged
!> \return the value of the composed polynomial at x
! **************************************************************************************************
   PURE FUNCTION evaluate_trs4_polynomial(x, gamma_values, i) RESULT(xr)
      REAL(KIND=dp), INTENT(IN)                          :: x
      REAL(KIND=dp), DIMENSION(:), INTENT(IN)            :: gamma_values
      INTEGER, INTENT(IN)                                :: i
      REAL(KIND=dp)                                      :: xr

      ! must match gamma_max/gamma_min of density_matrix_trs4
      REAL(KIND=dp), PARAMETER                           :: gam_max = 6.0_dp, gam_min = 0.0_dp

      INTEGER                                            :: k

      xr = x
      DO k = 1, i
         IF (gamma_values(k) > gam_max) THEN
            xr = 2*xr - xr**2
         ELSE IF (gamma_values(k) < gam_min) THEN
            xr = xr**2
         ELSE
            xr = (xr*xr)*(4*xr - 3*xr*xr) + gamma_values(k)*xr*xr*((1 - xr)**2)
         END IF
      END DO
   END FUNCTION evaluate_trs4_polynomial

! **************************************************************************************************
!> \brief estimate how many purification steps are still needed to push the scaled HOMO bound
!>        to 1 and the scaled LUMO bound to 0, each within threshold
!> \param homo scaled HOMO bound
!> \param lumo scaled LUMO bound
!> \param threshold tolerance on |lumo| and |1 - homo|
!> \return number of steps taken before both bounds are within threshold (max_steps if never)
! **************************************************************************************************
   PURE FUNCTION estimate_steps(homo, lumo, threshold) RESULT(steps)
      REAL(KIND=dp), INTENT(IN)                          :: homo, lumo, threshold
      INTEGER                                            :: steps

      ! upper bound on the number of simulated steps
      INTEGER, PARAMETER                                 :: max_steps = 200

      INTEGER                                            :: i
      REAL(KIND=dp)                                      :: h, l, m

      l = lumo
      h = homo

      DO i = 1, max_steps
         IF (ABS(l) < threshold .AND. ABS(1 - h) < threshold) EXIT
         ! use x^2 when the midpoint lies above 1/2, 2x-x^2 otherwise
         m = 0.5_dp*(h + l)
         IF (m > 0.5_dp) THEN
            h = h**2
            l = l**2
         ELSE
            h = 2*h - h**2
            l = 2*l - l**2
         END IF
      END DO
      ! after a completed loop i == max_steps + 1, so steps == max_steps
      steps = i - 1
   END FUNCTION estimate_steps

END MODULE dm_ls_scf_methods
